import argparse
import os
import json
from tqdm import tqdm
import numpy as np
import sys
print("Actual received parameters:", sys.argv)
import re
from symeval import EvaluatorMathBatch

# Shared symeval evaluator used by eval() below. It is constructed once at
# module import time, so importing this file requires the `symeval` package.
symeval_evaluator = EvaluatorMathBatch()


# ============================================================
#                          setting
# ============================================================
def _parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script does.

    Returns:
        argparse.Namespace with ``root_dir``, ``gt_answer_key``, ``eval_mode``,
        ``parse_pred_mode`` and ``id_key``.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--root_dir', type=str, required=True,
                        help='directory containing output.jsonl; the metric file is written here')
    parser.add_argument('--gt_answer_key', type=str, default='answer',
                        help='record key holding the reference answer')
    # Previously had no default: omitting it produced eval_mode=None, which
    # only failed later inside eval(). 'symeval' is the only implemented mode.
    parser.add_argument('--eval_mode', type=str, default='symeval',
                        help='answer comparison backend (only "symeval" is implemented)')
    parser.add_argument('--parse_pred_mode', type=str, default='parse_boxed',
                        help='how to extract the predicted answer (only "parse_boxed" is implemented)')
    parser.add_argument('--id_key', type=str, default='id')
    return parser.parse_args(argv)



# ============================================================
#                           utils
# ============================================================
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank or whitespace-only lines (e.g. extra trailing newlines) are skipped
    instead of making ``json.loads`` raise. The file is decoded as UTF-8, the
    JSON standard encoding, so non-ASCII text written with
    ``ensure_ascii=False`` reads back regardless of the platform default.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def load_json(path):
    """Load and return the JSON document stored at *path* (decoded as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_jsonl(x, save_path):
    """Write each object of the iterable *x* as one JSON line to *save_path*.

    Missing parent directories are created. A bare filename (no directory
    component) is written to the current working directory; previously
    ``os.makedirs('')`` raised ``FileNotFoundError`` in that case. Output is
    UTF-8 because non-ASCII characters are written unescaped
    (``ensure_ascii=False``).
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as file:
        for obj in x:
            file.write(json.dumps(obj, ensure_ascii=False))
            file.write('\n')

def save_json(x, save_path):
    """Write *x* as indented UTF-8 JSON to *save_path* and print the path.

    Missing parent directories are created. A bare filename is written to the
    current working directory instead of failing on ``os.makedirs('')``.
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(x, f, indent=4, ensure_ascii=False)
    print('saved to: ', save_path)
def file_hard(dataset):
    """Return a new list with the records whose ``int(record['level'])`` is >= 5.

    The input sequence is left untouched, and the relative order of the kept
    records is preserved.
    """
    return [record for record in dataset if int(record['level']) >= 5]


# ============================================================
#                            eval
# ============================================================

def eval(eval_mode, gt, pred):
    """Score one prediction against its reference answer.

    NOTE: this name shadows the builtin ``eval`` at module scope. It is kept
    unchanged because callers use it.

    Args:
        eval_mode: Comparison backend; only ``'symeval'`` is implemented.
        gt: Reference answer; converted with ``str()`` before comparison.
        pred: Predicted answer, passed to symeval unchanged.

    Returns:
        The first (only) element of ``symeval_evaluator.batch_eq`` for the
        single (gt, pred) pair. It is presumably a bool-like equivalence
        verdict; confirm against symeval.

    Raises:
        NotImplementedError: if ``eval_mode`` is not ``'symeval'``. This
            includes ``None``, which is what an omitted ``--eval_mode``
            used to produce.
    """
    if eval_mode != 'symeval':
        raise NotImplementedError(f"unsupported eval_mode: {eval_mode!r} (expected 'symeval')")
    scores = symeval_evaluator.batch_eq(ref_answers=[str(gt)], pred_answers=[pred])
    return scores[0]



# ============================================================
#                            parse
# ============================================================

def parse_answer_boxed(pred_str):
    r"""Extract the answer that follows the last ``boxed`` in *pred_str*.

    The text after the last occurrence of the substring ``"boxed"`` is
    inspected:

    * If it starts with ``{``, the contents up to the matching ``}`` are
      returned with nested braces kept (``\boxed{\frac{1}{2}}`` ->
      ``\frac{1}{2}``). If the braces never balance, everything after the
      opening ``{`` is returned.
    * Otherwise the text up to the next ``$`` is returned, stripped, which
      covers ``\boxed 5$``-style output.

    Returns ``""`` when *pred_str* is empty or ``None``, contains no
    ``boxed``, or nothing follows the last ``boxed``.
    """
    if not pred_str or 'boxed' not in pred_str:
        return ""
    # Text after the last "boxed"; same as split("boxed")[-1].
    tail = pred_str.rsplit("boxed", 1)[-1]
    if not tail:
        return ""
    if tail[0] != "{":
        return tail.split("$")[0].strip()
    # Scan for the brace that closes the opening "{" at tail[0]. Slicing
    # avoids building the result one character at a time.
    depth = 1
    for idx in range(1, len(tail)):
        ch = tail[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return tail[1:idx]
    # Unbalanced braces: keep everything after the opening brace.
    return tail[1:]
def parse_answer(pred_str):
    """Return the integer that follows a ``####`` marker, or ``None``.

    main() uses this for gsm8k references, whose final answer follows
    ``####``. A leading minus sign and thousands separators are accepted:
    ``"#### 1,234"`` -> ``1234`` (previously ``1``) and ``"#### -3"`` -> ``-3``
    (previously ``None``). A decimal part is dropped: ``"#### 3.5"`` -> ``3``.
    """
    match = re.search(r'####\s*(-?\d[\d,]*)', pred_str)
    if match is None:
        return None
    return int(match.group(1).replace(',', ''))

def parse_pred_answer(parse_mode, pred_str):
    """Extract the predicted answer from the model output *pred_str*.

    Args:
        parse_mode: Extraction strategy; only ``'parse_boxed'`` (delegates to
            ``parse_answer_boxed``) is implemented.
        pred_str: Raw model output text.

    Raises:
        NotImplementedError: for any other ``parse_mode``.
    """
    if parse_mode != 'parse_boxed':
        raise NotImplementedError(
            f"unsupported parse_pred_mode: {parse_mode!r} (expected 'parse_boxed')")
    return parse_answer_boxed(pred_str)



# ============================================================
#                            main
# ============================================================
def _mean_or_none(values):
    """Return the mean of *values* as a float, or ``None`` for an empty list.

    ``np.mean([])`` warns and returns NaN, which ``json.dump`` writes as the
    non-standard token ``NaN``. ``None`` is written as ``null`` instead.
    """
    return float(np.mean(values)) if values else None


def main():
    """Score ``<root_dir>/output.jsonl`` and save ``metric_<eval_mode>.json``.

    For each record, the predicted answer is parsed from
    ``generation_info['full_output']`` and compared with the reference. When
    ``root_dir`` contains ``'gsm8k'``, the reference is the number after
    ``####`` in ``record['answer']``; otherwise it is
    ``record[gt_answer_key]``. Per-record progress is printed. Then the
    aggregates are printed and written to JSON in ``root_dir``: accuracy,
    id lists, output token counts, finish reasons and round counts.

    Each record must provide ``generation_info`` with the keys
    ``full_output``, ``num_round``, ``finish_reason`` and
    ``full_output_token_number``. Otherwise a ``KeyError`` is raised.
    """
    args = _parse_args()

    ### load
    result_path = os.path.join(args.root_dir, 'output.jsonl')
    results = load_jsonl(result_path)
    num_total = len(results)
    is_gsm8k = 'gsm8k' in args.root_dir

    correct_data_ids = []
    wrong_data_ids = []
    num_round_total = []
    each_data_eval_info = {}
    finish_reason_info = {}
    finish_reason_of_each_data_with_wrong_pred = {}
    finish_reason_correct = []
    finish_reason_wrong = []
    pred_num_token_list = []
    pred_num_token_list_correct = []
    pred_num_token_list_wrong = []

    ### eval + aggregate (single pass over the records)
    for i, res in enumerate(results):
        info = res['generation_info']

        pred_answer = parse_pred_answer(args.parse_pred_mode, info['full_output'])
        info['pred_answer'] = pred_answer

        if is_gsm8k:
            gt_answer = parse_answer(res['answer'])
        else:
            gt_answer = res[args.gt_answer_key]
        is_correct = eval(args.eval_mode, gt_answer, pred_answer)
        print(f'==> {i+1}/{num_total}  gt: {gt_answer}  pred: {pred_answer}  is_correct: {is_correct}')
        info['pred_is_correct'] = is_correct

        # Honour --id_key (was ignored before). Records without an id get
        # their 1-based position.
        data_id = res.setdefault(args.id_key, i + 1)
        finish_reason = info['finish_reason']
        pred_num_tokens = info['full_output_token_number']
        num_round_total.append(info['num_round'])
        pred_num_token_list.append(pred_num_tokens)

        if is_correct:
            correct_data_ids.append(data_id)
            pred_num_token_list_correct.append(pred_num_tokens)
            finish_reason_correct.append(finish_reason)
        else:
            wrong_data_ids.append(data_id)
            pred_num_token_list_wrong.append(pred_num_tokens)
            finish_reason_wrong.append(finish_reason)
            finish_reason_of_each_data_with_wrong_pred[data_id] = finish_reason

        finish_reason_info.setdefault(finish_reason, []).append(data_id)
        # Report the reference that was actually compared. For gsm8k this is
        # the parsed number, not the whole worked solution.
        each_data_eval_info[data_id] = (
            f'[score: {is_correct}]  [gt: {gt_answer}]  [pred: {pred_answer}]  '
            f'[pred_num_tokens: {pred_num_tokens}]  [finish_reason: {finish_reason}]'
        )

    ### metric
    num_correct = len(correct_data_ids)
    # An empty output.jsonl previously raised ZeroDivisionError here.
    avg_acc = (num_correct / num_total) * 100 if num_total else 0.0

    metrics = {
        'accuracy': avg_acc,
        'num_correct': num_correct,
        'num_total': num_total,
        'eval_mode': args.eval_mode,
        'parse_pred_mode': args.parse_pred_mode,
        #
        'correct_data_ids': correct_data_ids,
        'wrong_data_ids': wrong_data_ids,
        #
        'mean_output_num_tokens': _mean_or_none(pred_num_token_list),
        # Previously this key held the overall mean by mistake.
        'mean_output_num_tokens_for_data_with_correct_pred': _mean_or_none(pred_num_token_list_correct),
        'mean_output_num_tokens_for_data_with_wrong_pred': _mean_or_none(pred_num_token_list_wrong),
        'output_num_tokens': pred_num_token_list,
        'output_num_tokens_for_data_with_correct_pred': pred_num_token_list_correct,
        'output_num_tokens_for_data_with_wrong_pred': pred_num_token_list_wrong,
        #
        'each_data_eval_info': each_data_eval_info,
        #
        'finish_reason_of_each_data_with_wrong_pred': finish_reason_of_each_data_with_wrong_pred,
        'finish_reason_info': finish_reason_info,
        'finish_reason_for_data_with_correct_pred': finish_reason_correct,
        'finish_reason_for_data_with_wrong_pred': finish_reason_wrong,
        'num_round_total': num_round_total,
    }

    print('\n\n--------------------- metric ---------------------')
    print('==> acc: ', avg_acc)
    print(f'\n==> correct / total: {num_correct} / {num_total}')

    print('\n==> data_ids: ')
    print('    correct: ', correct_data_ids)
    print('    wrong:   ', wrong_data_ids)

    print('\n==> output_token_number: ')
    print('    correct: ', pred_num_token_list_correct)
    print('    wrong:   ', pred_num_token_list_wrong)

    print('\n==> finish reason: ')
    print('    correct: ', finish_reason_correct)
    print('    wrong:   ', finish_reason_wrong)

    print('\n==> number of round: ')
    print('    round: ', num_round_total)

    print('\n')

    ## save
    save_path = os.path.join(args.root_dir, f'metric_{args.eval_mode}.json')
    save_json(metrics, save_path)



# Run only when executed as a script, so importing this module (e.g. to reuse
# the parsers) does not immediately parse sys.argv and start evaluating.
if __name__ == '__main__':
    main()